OUTPUT_dir <- str_c(TRADING_FIGS)

# Bid-level panel (plant x period x date x bid, per the file name).
bid_data <- read_dta(str_c(TRADING_DATA_CLEAN, "panel_plant-period-date-bid.dta"))

# Simulated plant emission standards. Only simulation id_level 170000 is kept;
# Eit/id_period are renamed so the table joins to the bid data on
# (gpcb_id, period).
CC_sim_data <- read_csv(str_c(MODEL_DATA_OUT, "cc_heateps_output_itk_gpcb_id.csv")) %>%
  filter(id_level == 170000) %>%
  rename(plt_emission_standard = Eit,
         period = id_period) %>%
  select(gpcb_id, period, plt_emission_standard)

# Trading period shown in the figures below.
SelectPeriod <- 8

# Plants featured in the per-plant scatter figures (Figure 9 panels B and A).
selected_plants_list <- c(107825504, 104254492)

# Clearing price, on the same INR/kg scale as bid prices.
# NOTE(review): hard-coded; confirm it corresponds to SelectPeriod.
clearing_price_selected <- 11.246

# Uniform "example load standard": the mean simulated plant standard in the
# selected period. Uses SelectPeriod (not a literal period number) so this value
# stays consistent with the rest of the figures if the period is changed.
example_load_standard <- CC_sim_data %>%
  filter(period == SelectPeriod) %>%
  summarize(plt_emission_standard = mean(plt_emission_standard)) %>%
  pull(plt_emission_standard)

# Bid-level analysis table: selected bid columns, the simulated plant standard
# for each (gpcb_id, period), the number of bids per plant-period, a buy/sell
# label, the plant's industry, and a flag for the plants featured in Figure 9.
bid_data_proc <- bid_data %>%
  select(
    plant_period_id,
    gpcb_id,
    period,
    is_auction,
    week_num,
    bid_status,
    bid_day_norm,
    bid_qty,
    bid_id,
    trade_price,
    ln_permit_holding_hyp,
    permit_holding_hyp,
    ln_bid_price,
    bid_price,
    pm_mass_potential_max_12
  ) %>%
  # Explicit keys: these are the only columns shared with CC_sim_data.
  left_join(CC_sim_data, by = c("gpcb_id", "period")) %>%
  # Number of bids each plant submitted in each period.
  group_by(gpcb_id, period) %>%
  mutate(tot_bids = n()) %>%
  ungroup() %>%
  # Negative quantities are sells; everything else (including a missing
  # quantity) falls through to "Buy".
  mutate(bid_type = case_when(bid_qty < 0 ~ "Sell",
                              TRUE ~ "Buy")) %>%
  left_join(
    read_dta(str_c(BASELINE_DATA_OUT, "BaselineCovariates.dta")) %>%
      select(gpcb_id, industry_name),
    by = "gpcb_id"
  ) %>%
  mutate(selected_plant = gpcb_id %in% selected_plants_list)

## * Model ####

# Main regression specification, replicated from the original paper's Stata code:
# log bid price on log hypothetical permit holdings plus a full set of
# plant-period fixed effects (the "- 1" drops the common intercept).
# Estimation sample: non-rejected bids with a recorded quantity and
# bid_day_norm <= .5 (presumably the first half of the bidding window -- confirm).
model_dat <- bid_data_proc %>%
  filter(bid_status != "REJ" & !is.na(bid_qty) & bid_day_norm <= .5) %>%
  arrange(gpcb_id, period, bid_day_norm, bid_id)

mod_run <- lm(
  ln_bid_price ~ ln_permit_holding_hyp + factor(plant_period_id) - 1,
  data = model_dat
)
mod_result <- mod_run %>% summary()


## * Model Predictions ####

# Fitted marginal abatement cost (MAC) curves, one per plant-period whose fixed
# effect appears in the regression's coefficient table.
#
# Each plant-period is expanded into a grid of log-emission points
# (ln_mac_curve_support) from -10 to log(pm_mass_potential_max_12), with about
# pm_mass_potential_max_12 / 2 points (seq() rounds length.out up). At each point:
#   mac_pred    = exp(ln_x * model_hold + model_fe)
#   mac_curve_x = exp(ln_x)
#
# Coefficients are looked up by name rather than position. If the slope term
# were ever dropped as aliased by summary(), a positional lookup would silently
# use a fixed effect as the slope; a named lookup fails loudly instead.
model_preds <- bid_data_proc %>%
  filter(paste0("factor(plant_period_id)", plant_period_id) %in%
           rownames(mod_result$coefficients)) %>%
  group_by(plant_period_id) %>%
  slice_head(n = 1) %>%
  ungroup() %>%
  mutate(
    ln_mac_curve_support = map(
      pm_mass_potential_max_12,
      ~ seq(-10, log(.x), length.out = .x / 2)
    ),
    model_hold = mod_result$coefficients["ln_permit_holding_hyp", "Estimate"],
    model_fe = mod_result$coefficients[paste0("factor(plant_period_id)",
                                              as.character(plant_period_id)),
                                       "Estimate"]
  ) %>%
  unnest(ln_mac_curve_support) %>%
  mutate(mac_pred = exp(ln_mac_curve_support * model_hold + model_fe),
         mac_curve_x = exp(ln_mac_curve_support))

# Plotting data: bids from the selected period with bid_day_norm <= .5, stacked
# with the fitted MAC curve rows (all periods) from model_preds.
#
# Within each (gpcb_id, period), market_emissions is the mac_curve_x of the
# row whose mac_pred is closest to the clearing price. A single row is taken
# if there are ties. The result is left grouped by gpcb_id and period.
plot_dat <- bid_data_proc %>%
  filter(period == SelectPeriod, bid_day_norm <= .5) %>%
  bind_rows(model_preds) %>%
  group_by(gpcb_id, period) %>%
  group_modify(~ {
    emissions_at_price <- .x %>%
      slice_min(abs(mac_pred - clearing_price_selected), with_ties = FALSE) %>%
      pull(mac_curve_x)
    mutate(.x, market_emissions = emissions_at_price)
  })


# Calculate areas of gains from trade -------------------------------------
# Signed area between a fitted MAC curve exp(hold * log(x) + fe) and a price
# over [lower, upper]:
#   - clip = pmax keeps only the part where the MAC is above the price (>= 0);
#   - clip = pmin keeps only the part where it is below the price (<= 0).
mac_price_gap_area <- function(hold, fe, lower, upper, clip,
                               price = clearing_price_selected) {
  integrand <- function(x) clip(0, exp(hold * log(x) + fe) - price)
  integrate(integrand, lower, upper)$value
}

plot_dat %>%
  filter(gpcb_id %in% selected_plants_list & period == SelectPeriod) %>%
  ungroup() %>%
  # model_preds rows are bound after the bids in plot_dat, so the last row per
  # plant carries model_hold / model_fe (assuming the plant has a fitted curve).
  slice_tail(n = 1, by = "gpcb_id") %>%
  select(gpcb_id, model_hold, model_fe, plt_emission_standard) %>%
  mutate(
    # Positive part of (MAC - price) at emissions above the example standard.
    buy = map2_dbl(model_hold, model_fe,
                   ~ mac_price_gap_area(.x, .y, example_load_standard, Inf, pmax)),
    # Negative part of (MAC - price) at emissions below the example standard.
    sell = map2_dbl(model_hold, model_fe,
                    ~ mac_price_gap_area(.x, .y, 0, example_load_standard, pmin))
  )

# Emissions level at which each selected plant's fitted MAC curve equals the
# clearing price. Solving exp(model_hold * log(x) + model_fe) = price for x gives
# x = exp((log(price) - model_fe) / model_hold), so no fine grid search is needed.
# mac_curve is re-evaluated at that x as a check (it should equal the price).
plot_dat %>%
  filter(gpcb_id %in% selected_plants_list & period == SelectPeriod) %>%
  ungroup() %>%
  slice_tail(n = 1, by = "gpcb_id") %>%
  select(gpcb_id, model_hold, model_fe, plt_emission_standard) %>%
  mutate(
    x = exp((log(clearing_price_selected) - model_fe) / model_hold),
    mac_curve = exp(model_hold * log(x) + model_fe)
  )

# Prepare Graphing --------------------------------------------------------

# Shared plotting constants.
# NOTE(review): line_arg, scatter_arg, label_perc and orange_color are not
# referenced in the code visible here; kept in case other code relies on them.
line_arg <- "emissions_line"
scatter_arg <- "emissions_scatter"
alpha_val <- .25
xlab_var <- "Plant emissions (kg)"
ylab_var <- "Bid Price (INR/kg)"
label_perc <- .5

# Two colours sampled once from the turbo palette, plus one inferno colour
# used for the shaded/striped abatement-cost areas.
turbo_pair <- viridis_pal(option = "turbo", begin = .2, end = .75)(2)
orange_color <- turbo_pair[2]
blue_color <- turbo_pair[1]
red_color <- viridis_pal(option = "inferno", begin = .6, end = .75)(1)[1]

# Base ggplot shared by all figures: the selected plants' rows of plot_dat with
# the common axis labels, a classic theme, serif text and optional font sizes
# (NULL leaves a size at the theme default). Extra arguments in `...` are
# accepted and ignored, so the function can be called via pmap() over the full
# figure-parameter row.
plot_skel_maker <- function(legend_title_size = NULL,
                            legend_text_size = NULL,
                            strip_text_size = NULL,
                            axis_text_size = NULL,
                            axis_title_size = NULL,
                            ...) {
  plot_dat %>%
    filter(selected_plant) %>%
    ggplot() +
    xlab(xlab_var) +
    ylab(ylab_var) +
    theme_classic() +
    theme(legend.title = element_text(size = legend_title_size),
          strip.background = element_blank(),
          strip.text.x = element_text(size = strip_text_size, face = "bold"),
          legend.text = element_text(size = legend_text_size),
          axis.text = element_text(size = axis_text_size),
          axis.title = element_text(size = axis_title_size),
          text = element_text(family = "serif"))
}



# Scatter Animate ---------------------------------------------------------

# Builds the per-plant MAC figures (Figure 9). For each plant in
# selected_plants_list, the plot (restricted to that plant in SelectPeriod) adds,
# in order:
#   1. the fitted MAC curve;
#   2. the uniform example load standard with its label;
#   3. a "Variable abatement cost" label;
#   4. the clearing-price line with its label;
#   5. light shading under the MAC for emissions >= market_emissions;
#   6. a striped band between market_emissions and the example standard, with
#      its added/foregone label;
#   7. the gains-from-trade ribbon between the MAC and the clearing price, with
#      its label.
#
# Size arguments are usually supplied by pmap() from the figure-parameter row;
# `...` absorbs the columns this function does not use.
#
# Returns a list named by plant id. Each element is a list holding the finished
# plot under "_total".
#
# Stops with an error for a plant that has no hard-coded axis settings, rather
# than plotting with undefined or previous-iteration settings.
scatter_animate_tibble <- function(plot_skel,
                                   abline_size = NULL,
                                   area_label_size = NULL,
                                   scatter_line_size = NULL,
                                   scatter_load_lbl_ht = NULL,
                                   scatter_point_size = NULL,
                                   abline_label_size,
                                   ...) {
  all_scatters <- list()

  for (plt in selected_plants_list) {
    local_scatters <- list()

    ## **** Plant Specific Values
    # Axis limits/breaks, clearing-price label x-position and legend placement.
    if (plt == 107825504) {
      xmax <- 6000
      y_max <- 30
      xbreaks <- seq(0, xmax, 1000)
      max_clearing_lab_x <- 4800
      legend_pos <- c(.85, .85)
    } else if (plt == 104254492) {
      xmax <- 1500
      y_max <- 30
      xbreaks <- seq(0, xmax, 250)
      max_clearing_lab_x <- 1400
      legend_pos <- "none"
    } else {
      stop("No scatter plot settings defined for plant ", plt, call. = FALSE)
    }

    ## **** General Scatter Graph Setup
    scatter_animate <- plot_skel %+%
      filter(plot_dat,
             gpcb_id == plt,
             period == SelectPeriod) +
      scale_y_continuous(expand = c(0, 0), limits = c(0, y_max)) +
      ylab("Bid price (INR/kg)") +
      scale_x_continuous(limits = c(0, xmax),
                         breaks = xbreaks) +
      scale_shape_manual(
        values = c(2, 21), name = "Bid Type",
        guide = guide_legend(
          override.aes = list(alpha = 1, size = scatter_point_size),
          order = 2
        )) +
      theme(legend.position = legend_pos)

    ## **** MAC
    scatter_animate <- scatter_animate +
      geom_line(aes(
        x = mac_curve_x,
        y = mac_pred,
        group = plant_period_id),
        color = blue_color,
        lwd = scatter_line_size
      )

    ## **** MAC + Ld Std + Full Shaded
    scatter_animate <- scatter_animate +
      geom_vline(xintercept = example_load_standard, linetype = 2, lwd = abline_size) +
      geom_text(data = tribble(
        ~gpcb_id, ~x_var, ~y_var, ~hjust,
        107825504, example_load_standard + 30, 28, "left",
        104254492, example_load_standard - 13, 28, "right",
        # If use plt_emissions_standard
        #255049024, 600,28, "left",
        #544928192, 750,28, "right",
      ) %>%
        filter(gpcb_id == plt),
      aes(x = x_var,
          hjust = hjust),
      y = scatter_load_lbl_ht,
      label = "Example load standard",
      #angle=-90,
      size = abline_label_size,
      family = "serif")

    # Plain geom_text accepts no segment/arrow parameters (those belong to
    # ggrepel), so only text styling is passed here.
    scatter_animate <- scatter_animate +
      geom_text(data = tribble(
        ~gpcb_id, ~x_var, ~y_var, ~label_var, ~move_y, ~move_x,
        107825504, 5000, 2, "Variable\nabatement cost", 2, 5000,
        104254492, 1300, 2, "Variable\nabatement cost", 2, 1300) %>%
          #255049024, 5500, 4, "Variable\nabatement cost", 8, 5500,
          #544928192, 500,5, "Variable\nabatement cost", 5, 250) %>%
          left_join(plot_dat, by = "gpcb_id") %>%
          filter(gpcb_id == plt) %>%
          slice_head(n = 1),
        #aes(x=x_var, y=y_var, label=label_var, nudge_y=move_y, nudge_x=move_x),
        aes(x = x_var, y = y_var, label = label_var),
        color = "black",
        size = area_label_size,
        fontface = 1,
        lineheight = 1.05,
        family = "serif")

    ## ****MAC + CC + Clearing Price
    scatter_animate <- scatter_animate +
      geom_hline(yintercept = clearing_price_selected, linetype = 2, lwd = abline_size) +
      annotate("text",
               x = max_clearing_lab_x,
               y = clearing_price_selected + .04 * y_max,
               label = "Clearing price",
               size = abline_label_size,
               family = "serif")

    # Shade under the MAC curve for emissions at or above market_emissions.
    scatter_animate <- scatter_animate +
      geom_ribbon(aes(x = mac_curve_x, ymin = 0, ymax = mac_pred),
                  fill = red_color,
                  alpha = .2,
                  data = ~ filter(.x, mac_curve_x >= market_emissions))

    ## **** MAC + CC + Clearing Price + Add/Foregone
    # Striped band under the MAC between market_emissions and the example standard.
    scatter_animate <- scatter_animate +
      geom_ribbon_pattern(aes(x = mac_curve_x, ymin = 0, ymax = mac_pred),
                          fill = red_color,
                          alpha = 0,
                          pattern = "stripe",
                          pattern_alpha = .65,
                          pattern_spacing = .05,
                          pattern_size = .05,
                          pattern_fill = red_color,
                          pattern_color = red_color,
                          data = ~ filter(.x,
                                          mac_curve_x <= pmax(market_emissions, example_load_standard) &
                                            mac_curve_x >= pmin(market_emissions, example_load_standard))
      )
    scatter_animate <- scatter_animate +
      geom_text_repel(data = tribble(
        ~gpcb_id, ~x_var, ~y_var, ~label_var, ~move_y, ~move_x,
        107825504, 1300, 5, "Foregone\nabatement cost\n(cross-hatch)", 5, 250,
        104254492, 500, 5, "Added\nabatement cost\n(cross-hatch)", 5, 150) %>%
        # If use plt_emissions_standard
        #255049024, 2000, 8, "Foregone\nabatement cost\n(cross-hatch)", 14, 2700,
        #544928192, 500,5, "Added\nabatement cost\n(cross-hatch)", 5, 250) %>%
          left_join(plot_dat, by = "gpcb_id") %>%
          filter(gpcb_id == plt) %>%
          slice_head(n = 1),
        aes(x = x_var, y = y_var, label = label_var, nudge_y = move_y, nudge_x = move_x),
        color = "black",
        segment.size = .4,
        size = area_label_size,
        fontface = 1,
        lineheight = 1.05,
        min.segment.length = 0,
        arrow = arrow(length = unit(0.015, "npc")),
        family = "serif")

    ## **** Add Gains from Trade
    # Ribbon between the MAC curve and the clearing price over the same band.
    scatter_animate <- scatter_animate +
      geom_ribbon(aes(x = mac_curve_x,
                      ymax = pmax(mac_pred, clearing_price_selected),
                      ymin = pmin(mac_pred, clearing_price_selected)),
                  fill = blue_color,
                  alpha = .4,
                  data = ~ filter(.x,
                                  mac_curve_x <= pmax(market_emissions, example_load_standard) &
                                    mac_curve_x >= pmin(market_emissions, example_load_standard))
      )
    # segment.size is given once (.4), matching the other repelled labels.
    scatter_animate <- scatter_animate +
      geom_text_repel(data = ~ tribble(
        ~gpcb_id, ~x_var, ~y_var, ~label_var, ~move_y, ~move_x,
        107825504, 1200, 13.4, "Gains from\nbuying permits", 13.4, 300,
        104254492, 750, 9, "Gains from\nselling permits", 14, 700) %>%
          # if use plt_emissions_standard
        # 255049024, 1000, 14, "Gains from\nbuying permits\n(blue)", 19, 1500,
        # 544928192, 625,9, "Gains from\nselling permits\n(blue)", 14, 575) %>%
          filter(gpcb_id %in% .x$gpcb_id),
        aes(x = x_var, y = y_var, label = label_var, nudge_y = move_y, nudge_x = move_x),
        color = "black",
        segment.size = .4,
        size = area_label_size,
        fontface = 1,
        lineheight = 1.05,
        min.segment.length = 0,
        arrow = arrow(length = unit(0.015, "npc")),
        family = "serif")

    local_scatters[["_total"]] <- scatter_animate

    all_scatters[[as.character(plt)]] <- local_scatters
  }
  all_scatters
}

# Builds the all-plant MAC line figure (Figure 7) for the selected period:
# - every fitted MAC curve as a translucent line;
# - a point on each plant's curve nearest its own simulated standard;
# - the uniform example load standard as a dashed vertical line with a label.
# Extra arguments in `...` are accepted and ignored (pmap() passes the whole
# parameter row).
line_plot_tibble <- function(plot_skel,
                             abline_size,
                             abline_label_size,
                             scatter_point_size,
                             ...) {
  # Per plant, the curve row(s) whose x is closest to plt_emission_standard.
  nearest_own_standard <- function(curves) {
    curves %>%
      group_by(gpcb_id) %>%
      slice_min(abs(mac_curve_x - plt_emission_standard))
  }

  mac_fig <- plot_skel %+% filter(model_preds, period == SelectPeriod)

  # Fitted MAC curves, one line per plant-period, plus axis limits.
  mac_fig <- mac_fig +
    geom_line(aes(x = mac_curve_x, y = mac_pred, group = plant_period_id),
              size = .45,
              color = blue_color,
              alpha = alpha_val) +
    ylab("Marginal abatement cost (INR/kg)") +
    xlim(0, 20000) +
    scale_y_continuous(expand = c(0, .5), limits = c(0, 20))

  # Markers for the non-uniform (simulated, per-plant) standards.
  mac_fig <- mac_fig +
    geom_point(aes(x = mac_curve_x,
                   y = mac_pred,
                   shape = "Non-uniform load\nstandard (one\nsimulation draw)"),
               alpha = .9,
               size = scatter_point_size - 1.5,
               color = "dodgerblue2",
               data = nearest_own_standard)

  # Uniform example standard and its label.
  mac_fig <- mac_fig +
    geom_vline(xintercept = example_load_standard, linetype = 2, lwd = abline_size) +
    annotate("text",
             x = 1200, y = 19,
             hjust = "left",
             label = "Uniform\nload standard",
             size = abline_label_size,
             family = "serif")

  # Untitled shape legend placed inside the panel.
  mac_fig +
    scale_shape_manual(values = 2) +
    theme(
      legend.background = element_blank(),
      legend.title = element_blank(),
      legend.position = c(.87, .9),
      legend.text = element_text(size = abline_label_size * 2.5)
    )
}


## *** Graph Creator

# Figure styling for the paper. The list becomes a one-row tibble whose columns
# are passed by name (via pmap) to the plot builders; each builder takes the
# entries it needs and absorbs the rest in `...`. The save_* entries are figure
# dimensions in millimetres (they are passed to ggsave with units = "mm").
paper_graph_params <- list(
  # Saved figure sizes (mm)
  save_height_scatter = 90,
  save_width_scatter = 130,
  save_height_line = 90,
  save_width_line = 200,
  # Per-plant scatter figure: curve width, load-standard label height, point size
  scatter_line_size = .8,
  scatter_load_lbl_ht = 28,
  scatter_point_size = 3,
  # NOTE(review): the next three are not read by the plot functions visible here
  sel_line_size = 1.1,
  nonsel_line_size = .35,
  line_label_size = 9,
  # Annotation text and reference-line sizes
  area_label_size = 3.2,
  abline_size = .5,
  abline_label_size = 3.4,
  line_abline_size = 5,
  # Theme font sizes used by plot_skel_maker
  legend_title_size = 15,
  legend_text_size = 12,
  strip_text_size = 0,
  axis_text_size = 10,
  axis_title_size = 10
)

## **** Save graphs

# One-row tibble: the figure parameters plus the shared plot skeleton.
graph_tibble0 <- paper_graph_params %>%
  as_tibble_row() %>%
  mutate(plot_skel = pmap(., plot_skel_maker))

# Build the MAC line figure (Figure 7) and write it to OUTPUT_dir; `saving`
# keeps ggsave()'s return value.
graph_tibble_line_plot_saved <- graph_tibble0 %>%
  mutate(
    line_plot = pmap(., line_plot_tibble),
    line_plot_filename = paste0(OUTPUT_dir, "Figure_7.pdf"),
    saving = pmap(
      list(filename = line_plot_filename,
           plot = line_plot,
           height = save_height_line,
           width = save_width_line),
      function(filename, plot, height, width) {
        ggsave(filename = filename,
               plot = plot,
               height = height,
               width = width,
               units = "mm")
      }
    )
  )

# Build and save the per-plant scatter figures (Figure 9).
# scatter_animate_tibble() returns list(<plant id> = list("_total" = <plot>)).
# The two unnest_longer() calls flatten that into one row per plant and plot
# type, storing the list names in gpcb_id and graph_type. The two paper plants
# get their figure names; any other plant falls back to
# "<gpcb_id><graph_type>.pdf".
graph_tibble_line_plot_saved %>%
  select(-c(line_plot, saving)) %>%
  mutate(animated_scatter_plots = pmap(., scatter_animate_tibble)) %>%
  unnest_longer(animated_scatter_plots) %>%
  rename(gpcb_id = animated_scatter_plots_id) %>%
  unnest_longer(animated_scatter_plots) %>%
  rename(graph_type = animated_scatter_plots_id) %>%
  mutate(scatter_plot_filename =
           str_c(
             OUTPUT_dir,
             case_when(
               gpcb_id == "104254492" ~ "Figure_9_A",
               gpcb_id == "107825504" ~ "Figure_9_B",
               TRUE ~ paste0(gpcb_id, graph_type)
               ), ".pdf"),
         saving = pmap(list(scatter_plot_filename,
                            animated_scatter_plots,
                            save_height_scatter,
                            save_width_scatter),
                       ~ ggsave(filename = ..1,
                                plot = ..2,
                                height = ..3,
                                width = ..4,
                                units = "mm")))
                                
                                
